package cn.turboinfo.dongying.api.provider.common.config

import com.fasterxml.jackson.core.JsonGenerator
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.core.JsonToken
import com.fasterxml.jackson.databind.*
import com.fasterxml.jackson.databind.module.SimpleModule
import mu.KotlinLogging
import net.sunshow.toolkit.core.base.enums.BaseEnum
import org.apache.commons.beanutils.BeanUtils
import org.springframework.beans.factory.support.BeanDefinitionBuilder
import org.springframework.beans.factory.support.BeanDefinitionRegistry
import org.springframework.context.EnvironmentAware
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar
import org.springframework.core.env.Environment
import org.springframework.core.type.AnnotationMetadata
import org.springframework.core.type.filter.AssignableTypeFilter
import org.springframework.util.ReflectionUtils
import java.io.IOException

/**
 * Registers a Jackson [com.fasterxml.jackson.databind.Module] bean (named `defaultCustomEnumJacksonModule`)
 * that serializes every [BaseEnum] implementation found under [basePackages] as the integer read from its
 * `value` property, and deserializes it back through the type's `get(int)` method.
 *
 * NOTE(review): presumably the registered `Module` bean is picked up by the application's Jackson
 * auto-configuration and applied to the `ObjectMapper` — confirm.
 */
class CustomEnumJacksonBeanDefinitionRegistrar : ImportBeanDefinitionRegistrar, EnvironmentAware {

    private lateinit var environment: Environment

    /** Package patterns (Ant-style wildcards allowed) scanned for [BaseEnum] implementations. */
    private val basePackages = arrayOf(
        "net.sunshow.toolkit.core.base.enums",
        "**.entity.**.enumeration",
    )

    override fun setEnvironment(environment: Environment) {
        this.environment = environment
    }

    /**
     * Scans [basePackages] for [BaseEnum] implementations, builds a [SimpleModule] containing a value-based
     * serializer/deserializer pair for each of them, and registers that module as a bean definition.
     */
    override fun registerBeanDefinitions(importingClassMetadata: AnnotationMetadata, registry: BeanDefinitionRegistry) {
        // useDefaultFilters = false: only the BaseEnum assignability filter decides what is a candidate.
        val provider = ClassPathScanningCandidateComponentProvider(false, environment).apply {
            addIncludeFilter(AssignableTypeFilter(BaseEnum::class.java))
        }

        val module = SimpleModule()
        // The serializer holds no per-type state, so a single instance is shared by all enum types.
        val serializer = CustomValueEnumSerializer()

        basePackages
            .flatMap { basePackage -> provider.findCandidateComponents(basePackage) }
            // getBeanClassName() is nullable; the patterns may also overlap, so de-duplicate class names.
            .mapNotNull { beanDefinition -> beanDefinition.beanClassName }
            .distinct()
            .forEach { className ->
                @Suppress("UNCHECKED_CAST")
                val enumClass = Class.forName(className) as Class<Any>
                module.addSerializer(enumClass, serializer)
                module.addDeserializer(enumClass, CustomValueEnumDeserializer(enumClass))
            }

        // Register the module as a bean for later processing.
        registry.registerBeanDefinition(
            "defaultCustomEnumJacksonModule",
            BeanDefinitionBuilder
                .genericBeanDefinition(com.fasterxml.jackson.databind.Module::class.java) { module }
                .beanDefinition,
        )
    }

    /**
     * Writes an enum as the integer obtained from its `value` bean property.
     *
     * If that property cannot be read or converted to an Int, the error is logged and JSON `null` is
     * written instead, so exactly one value is always emitted and the surrounding JSON stays well-formed.
     */
    class CustomValueEnumSerializer : JsonSerializer<Any>() {

        private val logger = KotlinLogging.logger {}

        @Throws(IOException::class)
        override fun serialize(value: Any, gen: JsonGenerator, serializers: SerializerProvider) {
            // Only the reflective property read is best-effort; generator I/O errors must propagate.
            val code: Int? = try {
                BeanUtils.getProperty(value, "value").toInt()
            } catch (e: Exception) {
                logger.error(e) {
                    "自定义enum序列化出错, value=$value"
                }
                null
            }
            if (code == null) gen.writeNull() else gen.writeNumber(code)
        }
    }

    /**
     * Reads an enum of [type] from a JSON integer or a numeric string by invoking `type.get(int)` with a
     * null target (so `get(int)` must be static).
     *
     * Returns `null` (after logging) for non-numeric strings or when [type] has no `get(int)` method.
     * Any other JSON token also yields `null`; object/array values are skipped entirely so the parser
     * remains positioned correctly for the rest of the document.
     */
    class CustomValueEnumDeserializer(private val type: Class<Any>) :
        JsonDeserializer<Any>() {

        private val logger = KotlinLogging.logger {}

        // Resolved once per enum type instead of on every deserialize call; null when the method is absent.
        private val lookupMethod = ReflectionUtils.findMethod(type, "get", Int::class.javaPrimitiveType)

        override fun deserialize(p: JsonParser, ctxt: DeserializationContext): Any? {
            val value: Int = when (p.currentToken) {
                JsonToken.VALUE_NUMBER_INT -> p.intValue
                JsonToken.VALUE_STRING -> parseIntOrNull(p.valueAsString) ?: return null
                else -> {
                    // Consume an unsupported object/array value completely; otherwise its children would be
                    // read as the following properties. For scalar tokens this is a no-op.
                    p.skipChildren()
                    return null
                }
            }
            val method = lookupMethod
            if (method == null) {
                logger.error { "未找到反序列化方法" }
                return null
            }
            return ReflectionUtils.invokeMethod(method, null, value)
        }

        /** Parses [s] as a decimal Int; logs and returns null when it is not a valid number. */
        private fun parseIntOrNull(s: String): Int? =
            try {
                s.toInt()
            } catch (e: NumberFormatException) {
                logger.error(e) {
                    "解析自定义enum出错, 无效的值: $s"
                }
                null
            }
    }
}